{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from keras.models import Model\n",
    "from keras.layers import Input\n",
    "from keras.legacy.layers import Highway\n",
    "from keras import backend as K\n",
    "import json\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def format_decimal(arr, places=6):\n",
    "    \"\"\"Round each number in `arr` to `places` decimal places.\n",
    "\n",
    "    Every value is multiplied by 10**places, rounded to the nearest\n",
    "    integer with the built-in round(), and divided back down.\n",
    "    Returns the rounded values as a new list, in the original order.\n",
    "    \"\"\"\n",
    "    scale = 10**places\n",
    "    rounded = []\n",
    "    for value in arr:\n",
    "        rounded.append(round(value * scale) / scale)\n",
    "    return rounded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Registry of generated fixtures keyed by test-case name; the cells below add\n",
    "# one entry each, and the OrderedDict keeps them in insertion order.\n",
    "DATA = OrderedDict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Highway"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**[legacy.Highway.0]**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/leon/miniconda3/lib/python3.5/site-packages/keras/legacy/layers.py:654: UserWarning: The `Highway` layer is deprecated and will be removed after 06/2017.\n",
      "  warnings.warn('The `Highway` layer is deprecated '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W shape: (6, 6)\n",
      "W: [0.176262, 0.795427, 0.783061, 0.631675, -0.928221, 0.383515, -0.242638, 0.037022, 0.315903, -0.6123, -0.455367, 0.437212, 0.566007, 0.700655, 0.55049, -0.926671, -0.766613, 0.502561, -0.521564, -0.490388, 0.715251, 0.899558, 0.123374, -0.642439, 0.540504, -0.015238, 0.262506, 0.678996, -0.077921, -0.00412, 0.358822, 0.301572, -0.46241, -0.865351, 0.54289, -0.038032]\n",
      "W_carry shape: (6, 6)\n",
      "W_carry: [-0.90255, -0.421781, 0.441933, -0.956768, -0.588154, -0.898453, -0.395456, 0.327821, -0.383771, 0.167183, -0.860858, 0.734809, -0.733519, -0.643751, -0.008141, 0.727399, 0.517888, 0.94097, 0.518605, -0.2315, -0.182563, 0.426721, -0.45866, 0.708206, 0.826328, 0.521512, 0.033337, -0.664328, -0.402615, -0.432114, -0.345575, -0.079719, 0.088737, -0.565198, 0.599737, 0.453369]\n",
      "b shape: (6,)\n",
      "b: [-0.583079, -0.036638, -0.158924, 0.718364, -0.657677, -0.322272]\n",
      "b_carry shape: (6,)\n",
      "b_carry: [0.034596, 0.893925, 0.53092, -0.435208, -0.557909, 0.372444]\n",
      "\n",
      "in shape: (6,)\n",
      "in: [-0.665722, -0.215115, 0.236105, -0.17614, -0.99507, 0.768064]\n",
      "out shape: (6,)\n",
      "out: [-0.6724, -0.131125, -0.713841, -0.86541, 0.009815, -0.272556]\n"
     ]
    }
   ],
   "source": [
    "# Highway layer on a 6-element input, linear activation, biases enabled.\n",
    "data_in_shape = (6,)\n",
    "layer_0 = Input(shape=data_in_shape)\n",
    "layer_1 = Highway(activation='linear', bias=True)(layer_0)\n",
    "model = Model(inputs=layer_0, outputs=layer_1)\n",
    "\n",
    "# Overwrite every weight array with values in [-1, 1); array i is drawn\n",
    "# right after seeding with 20+i so the fixture is reproducible.\n",
    "weights = []\n",
    "for i, initial_weight in enumerate(model.get_weights()):\n",
    "    np.random.seed(20 + i)\n",
    "    weights.append(2 * np.random.random(initial_weight.shape) - 1)\n",
    "model.set_weights(weights)\n",
    "for weight_name, weight in zip(['W', 'W_carry', 'b', 'b_carry'], weights):\n",
    "    print(weight_name + ' shape:', weight.shape)\n",
    "    print(weight_name + ':', format_decimal(weight.ravel().tolist()))\n",
    "\n",
    "# No reseeding here: the input continues the RNG stream left by the last weight seed.\n",
    "data_in = 2 * np.random.random(data_in_shape) - 1\n",
    "result = model.predict(np.array([data_in]))\n",
    "data_out = result[0]\n",
    "data_out_shape = data_out.shape\n",
    "data_in_formatted = format_decimal(data_in.ravel().tolist())\n",
    "data_out_formatted = format_decimal(data_out.ravel().tolist())\n",
    "print()\n",
    "print('in shape:', data_in_shape)\n",
    "print('in:', data_in_formatted)\n",
    "print('out shape:', data_out_shape)\n",
    "print('out:', data_out_formatted)\n",
    "\n",
    "# Record input, weights and expected output for this test case.\n",
    "weights_formatted = [\n",
    "    {'data': format_decimal(weight.ravel().tolist()), 'shape': weight.shape}\n",
    "    for weight in weights\n",
    "]\n",
    "DATA['legacy.Highway.0'] = {\n",
    "    'input': {'data': data_in_formatted, 'shape': data_in_shape},\n",
    "    'weights': weights_formatted,\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**[legacy.Highway.1]**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W shape: (5, 5)\n",
      "W: [0.288287, -0.238503, 0.326096, -0.672699, 0.925216, -0.306676, 0.983502, -0.529884, 0.171389, -0.18662, -0.727531, 0.088273, 0.036353, 0.53371, 0.8677, -0.820593, -0.608457, 0.988387, -0.529639, -0.522027, 0.2582, 0.469905, 0.376689, -0.937739, 0.805028]\n",
      "W_carry shape: (5, 5)\n",
      "W_carry: [-0.427892, 0.916211, 0.540626, 0.97374, -0.583669, -0.726166, 0.816748, -0.862723, -0.849335, 0.087069, -0.821201, -0.235213, 0.337121, -0.141662, -0.912087, -0.611428, -0.106681, -0.874853, -0.404865, 0.887262, -0.434378, -0.464613, -0.185562, 0.651981, 0.013401]\n",
      "b shape: (5,)\n",
      "b: [0.717779, -0.254578, 0.110258, 0.911313, 0.473339]\n",
      "b_carry shape: (5,)\n",
      "b_carry: [-0.50298, -0.100049, -0.178118, -0.479401, 0.740791]\n",
      "\n",
      "in shape: (5,)\n",
      "in: [-0.62992, -0.960677, 0.906504, 0.360902, -0.026824]\n",
      "out shape: (5,)\n",
      "out: [-0.44031, -0.941138, 0.764809, 0.556869, 0.346693]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/leon/miniconda3/lib/python3.5/site-packages/keras/legacy/layers.py:654: UserWarning: The `Highway` layer is deprecated and will be removed after 06/2017.\n",
      "  warnings.warn('The `Highway` layer is deprecated '\n"
     ]
    }
   ],
   "source": [
    "# Highway layer on a 5-element input, tanh activation, biases enabled.\n",
    "data_in_shape = (5,)\n",
    "layer_0 = Input(shape=data_in_shape)\n",
    "layer_1 = Highway(activation='tanh', bias=True)(layer_0)\n",
    "model = Model(inputs=layer_0, outputs=layer_1)\n",
    "\n",
    "# Overwrite every weight array with values in [-1, 1); array i is drawn\n",
    "# right after seeding with 30+i so the fixture is reproducible.\n",
    "weights = []\n",
    "for i, initial_weight in enumerate(model.get_weights()):\n",
    "    np.random.seed(30 + i)\n",
    "    weights.append(2 * np.random.random(initial_weight.shape) - 1)\n",
    "model.set_weights(weights)\n",
    "for weight_name, weight in zip(['W', 'W_carry', 'b', 'b_carry'], weights):\n",
    "    print(weight_name + ' shape:', weight.shape)\n",
    "    print(weight_name + ':', format_decimal(weight.ravel().tolist()))\n",
    "\n",
    "# No reseeding here: the input continues the RNG stream left by the last weight seed.\n",
    "data_in = 2 * np.random.random(data_in_shape) - 1\n",
    "result = model.predict(np.array([data_in]))\n",
    "data_out = result[0]\n",
    "data_out_shape = data_out.shape\n",
    "data_in_formatted = format_decimal(data_in.ravel().tolist())\n",
    "data_out_formatted = format_decimal(data_out.ravel().tolist())\n",
    "print()\n",
    "print('in shape:', data_in_shape)\n",
    "print('in:', data_in_formatted)\n",
    "print('out shape:', data_out_shape)\n",
    "print('out:', data_out_formatted)\n",
    "\n",
    "# Record input, weights and expected output for this test case.\n",
    "weights_formatted = [\n",
    "    {'data': format_decimal(weight.ravel().tolist()), 'shape': weight.shape}\n",
    "    for weight in weights\n",
    "]\n",
    "DATA['legacy.Highway.1'] = {\n",
    "    'input': {'data': data_in_formatted, 'shape': data_in_shape},\n",
    "    'weights': weights_formatted,\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**[legacy.Highway.2]**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W shape: (4, 4)\n",
      "W: [-0.184626, -0.889268, 0.57707, -0.42539, -0.099299, -0.392175, 0.052799, 0.247624, 0.553551, 0.372483, 0.961878, 0.201632, 0.627937, 0.41729, -0.944931, 0.808534]\n",
      "W_carry shape: (4, 4)\n",
      "W_carry: [-0.498153, -0.907808, 0.353632, -0.913061, -0.767153, 0.207731, -0.618139, 0.337031, 0.834896, -0.16244, -0.33548, -0.433933, -0.627435, -0.365779, -0.037663, -0.860959]\n",
      "\n",
      "in shape: (4,)\n",
      "in: [0.409965, -0.370646, 0.490565, -0.203574]\n",
      "out shape: (4,)\n",
      "out: [0.482075, -0.04199, 0.593448, 0.031503]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/leon/miniconda3/lib/python3.5/site-packages/keras/legacy/layers.py:654: UserWarning: The `Highway` layer is deprecated and will be removed after 06/2017.\n",
      "  warnings.warn('The `Highway` layer is deprecated '\n"
     ]
    }
   ],
   "source": [
    "# Highway layer on a 4-element input, hard_sigmoid activation, biases disabled.\n",
    "data_in_shape = (4,)\n",
    "layer_0 = Input(shape=data_in_shape)\n",
    "layer_1 = Highway(activation='hard_sigmoid', bias=False)(layer_0)\n",
    "model = Model(inputs=layer_0, outputs=layer_1)\n",
    "\n",
    "# Overwrite every weight array with values in [-1, 1); array i is drawn\n",
    "# right after seeding with 40+i so the fixture is reproducible.\n",
    "weights = []\n",
    "for i, initial_weight in enumerate(model.get_weights()):\n",
    "    np.random.seed(40 + i)\n",
    "    weights.append(2 * np.random.random(initial_weight.shape) - 1)\n",
    "model.set_weights(weights)\n",
    "# Only the two kernels are reported for this bias-free case.\n",
    "for weight_name, weight in zip(['W', 'W_carry'], weights):\n",
    "    print(weight_name + ' shape:', weight.shape)\n",
    "    print(weight_name + ':', format_decimal(weight.ravel().tolist()))\n",
    "\n",
    "# No reseeding here: the input continues the RNG stream left by the last weight seed.\n",
    "data_in = 2 * np.random.random(data_in_shape) - 1\n",
    "result = model.predict(np.array([data_in]))\n",
    "data_out = result[0]\n",
    "data_out_shape = data_out.shape\n",
    "data_in_formatted = format_decimal(data_in.ravel().tolist())\n",
    "data_out_formatted = format_decimal(data_out.ravel().tolist())\n",
    "print()\n",
    "print('in shape:', data_in_shape)\n",
    "print('in:', data_in_formatted)\n",
    "print('out shape:', data_out_shape)\n",
    "print('out:', data_out_formatted)\n",
    "\n",
    "# Record input, weights and expected output for this test case.\n",
    "weights_formatted = [\n",
    "    {'data': format_decimal(weight.ravel().tolist()), 'shape': weight.shape}\n",
    "    for weight in weights\n",
    "]\n",
    "DATA['legacy.Highway.2'] = {\n",
    "    'input': {'data': data_in_formatted, 'shape': data_in_shape},\n",
    "    'weights': weights_formatted,\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### export for Keras.js tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"legacy.Highway.0\": {\"weights\": [{\"shape\": [6, 6], \"data\": [0.176262, 0.795427, 0.783061, 0.631675, -0.928221, 0.383515, -0.242638, 0.037022, 0.315903, -0.6123, -0.455367, 0.437212, 0.566007, 0.700655, 0.55049, -0.926671, -0.766613, 0.502561, -0.521564, -0.490388, 0.715251, 0.899558, 0.123374, -0.642439, 0.540504, -0.015238, 0.262506, 0.678996, -0.077921, -0.00412, 0.358822, 0.301572, -0.46241, -0.865351, 0.54289, -0.038032]}, {\"shape\": [6, 6], \"data\": [-0.90255, -0.421781, 0.441933, -0.956768, -0.588154, -0.898453, -0.395456, 0.327821, -0.383771, 0.167183, -0.860858, 0.734809, -0.733519, -0.643751, -0.008141, 0.727399, 0.517888, 0.94097, 0.518605, -0.2315, -0.182563, 0.426721, -0.45866, 0.708206, 0.826328, 0.521512, 0.033337, -0.664328, -0.402615, -0.432114, -0.345575, -0.079719, 0.088737, -0.565198, 0.599737, 0.453369]}, {\"shape\": [6], \"data\": [-0.583079, -0.036638, -0.158924, 0.718364, -0.657677, -0.322272]}, {\"shape\": [6], \"data\": [0.034596, 0.893925, 0.53092, -0.435208, -0.557909, 0.372444]}], \"expected\": {\"shape\": [6], \"data\": [-0.6724, -0.131125, -0.713841, -0.86541, 0.009815, -0.272556]}, \"input\": {\"shape\": [6], \"data\": [-0.665722, -0.215115, 0.236105, -0.17614, -0.99507, 0.768064]}}, \"legacy.Highway.1\": {\"weights\": [{\"shape\": [5, 5], \"data\": [0.288287, -0.238503, 0.326096, -0.672699, 0.925216, -0.306676, 0.983502, -0.529884, 0.171389, -0.18662, -0.727531, 0.088273, 0.036353, 0.53371, 0.8677, -0.820593, -0.608457, 0.988387, -0.529639, -0.522027, 0.2582, 0.469905, 0.376689, -0.937739, 0.805028]}, {\"shape\": [5, 5], \"data\": [-0.427892, 0.916211, 0.540626, 0.97374, -0.583669, -0.726166, 0.816748, -0.862723, -0.849335, 0.087069, -0.821201, -0.235213, 0.337121, -0.141662, -0.912087, -0.611428, -0.106681, -0.874853, -0.404865, 0.887262, -0.434378, -0.464613, -0.185562, 0.651981, 0.013401]}, {\"shape\": [5], \"data\": [0.717779, -0.254578, 0.110258, 0.911313, 0.473339]}, {\"shape\": [5], \"data\": [-0.50298, 
-0.100049, -0.178118, -0.479401, 0.740791]}], \"expected\": {\"shape\": [5], \"data\": [-0.44031, -0.941138, 0.764809, 0.556869, 0.346693]}, \"input\": {\"shape\": [5], \"data\": [-0.62992, -0.960677, 0.906504, 0.360902, -0.026824]}}, \"legacy.Highway.2\": {\"weights\": [{\"shape\": [4, 4], \"data\": [-0.184626, -0.889268, 0.57707, -0.42539, -0.099299, -0.392175, 0.052799, 0.247624, 0.553551, 0.372483, 0.961878, 0.201632, 0.627937, 0.41729, -0.944931, 0.808534]}, {\"shape\": [4, 4], \"data\": [-0.498153, -0.907808, 0.353632, -0.913061, -0.767153, 0.207731, -0.618139, 0.337031, 0.834896, -0.16244, -0.33548, -0.433933, -0.627435, -0.365779, -0.037663, -0.860959]}], \"expected\": {\"shape\": [4], \"data\": [0.482075, -0.04199, 0.593448, 0.031503]}, \"input\": {\"shape\": [4], \"data\": [0.409965, -0.370646, 0.490565, -0.203574]}}}\n"
     ]
    }
   ],
   "source": [
    "# Emit every collected fixture as a single JSON string for the Keras.js tests;\n",
    "# shape tuples are serialized as JSON arrays.\n",
    "print(json.dumps(DATA))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
